#!/usr/bin/env python3

import argparse
import json

import deepspeed
from deepspeed.elasticity import compute_elastic_config

def _print_section(title):
    """Print ``title`` framed above and below by a dashed separator line."""
    print('------------------------------------------')
    print(title)
    print('------------------------------------------')


def main():
    """Load a DeepSpeed config and report its elastic batch-size settings.

    Command-line arguments:
        -c/--config      Path to a DeepSpeed JSON config containing an
                         ``elasticity`` section (required).
        -w/--world-size  Intended/current world size. When > 0, the result for
                         that specific world size (including the micro batch
                         size) is reported; otherwise only the final batch size
                         and the list of valid GPU counts are reported.

    The ``elasticity`` section is echoed first, followed by the values returned
    by ``compute_elastic_config`` for the installed DeepSpeed version.
    """
    parser = argparse.ArgumentParser()
    # Required: without it the script used to fail later with open(None).
    parser.add_argument('-c', '--config', type=str, required=True, help="DeepSpeed config json")
    parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
    args = parser.parse_args()

    # Context manager so the file handle is closed deterministically; JSON text
    # is UTF-8, so don't depend on the platform's locale encoding.
    with open(args.config, 'r', encoding='utf-8') as config_file:
        ds_config = json.load(config_file)

    # Fail with a clear usage error instead of a bare KeyError/TypeError traceback.
    if not isinstance(ds_config, dict) or 'elasticity' not in ds_config:
        parser.error(f"config file {args.config!r} has no 'elasticity' section")

    ds_version = deepspeed.__version__

    elastic_config = ds_config['elasticity']
    _print_section("Elasticity config:")
    print(json.dumps(elastic_config, indent=4, sort_keys=True))

    if args.world_size > 0:
        # With a world size, compute_elastic_config also returns the micro batch size.
        final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config,
                                                                                target_deepspeed_version=ds_version,
                                                                                world_size=args.world_size)
        _print_section(f"Calculated results for world size {args.world_size}:")
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')
        print(f'micro_batch_size .... {micro_batch_size}')
    else:
        final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
        _print_section("Calculated results:")
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')


if __name__ == '__main__':
    main()
